{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zeroshot demo\n",
    "In this notebook, you can create a zeroshot classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import h5py\n",
    "import os\n",
    "import pickle\n",
    "import yaml\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from transformers import AutoModel\n",
    "\n",
    "from titan.utils import get_eval_metrics, TEMPLATES, bootstrap\n",
    "\n",
    "os.environ[\"OMP_NUM_THREADS\"] = \"8\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model from huggingface\n",
    "model = AutoModel.from_pretrained('MahmoodLab/TITAN', trust_remote_code=True)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single feature classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load example data\n",
    "from huggingface_hub import hf_hub_download\n",
    "demo_h5_path = hf_hub_download(\n",
    "    \"MahmoodLab/TITAN\", \n",
    "    filename=\"TCGA_demo_features/TCGA-PC-A5DK-01Z-00-DX1.C2D3BC09-411F-46CF-811B-FDBA7C2A295B.h5\",\n",
    ")\n",
    "file = h5py.File(demo_h5_path, 'r')\n",
    "features = torch.from_numpy(file['features'][:])\n",
    "coords = torch.from_numpy(file['coords'][:])\n",
    "patch_size_lv0 = file['coords'].attrs['patch_size_level0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load configs and prompts for TCGA-OT task\n",
    "with open('../datasets/config_tcga-ot.yaml', 'r') as file:\n",
    "    task_config = yaml.load(file, Loader=yaml.FullLoader)\n",
    "class_prompts = task_config['prompts']\n",
    "target = task_config['target']\n",
    "label_dict = task_config['label_dict']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract slide embedding\n",
    "with torch.autocast('cuda', torch.float16), torch.inference_mode():\n",
    "    features = features.to(device)\n",
    "    coords = coords.to(device)\n",
    "    slide_embedding = model.encode_slide_from_patch_features(features, coords, patch_size_lv0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create prompts for zero-shot classification\n",
    "sorted_class_prompts = dict(sorted(class_prompts.items(), key=lambda item: label_dict.get(item[0], float('inf'))))\n",
    "classes = list(sorted_class_prompts.keys())\n",
    "class_prompts = [sorted_class_prompts[key] for key in sorted_class_prompts.keys()]\n",
    "with torch.autocast('cuda', torch.float16), torch.inference_mode():\n",
    "    classifier = model.zero_shot_classifier(class_prompts, TEMPLATES, device=device)  # will take approx 3 mins for 46 classes of TCGA-OncoTree (23 templates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.autocast('cuda', torch.float16), torch.inference_mode():\n",
    "    scores = model.zero_shot(slide_embedding, classifier)\n",
    "print(\"Predicted class:\", classes[scores.argmax()])\n",
    "print(\"Normalized similarity scores:\", [f\"{c}: {score:.3f}\" for c, score in zip(classes, scores[0][0])])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate classifier on TCGA-OncoTree\n",
    "Reproduce the zeroshot results on the dataset TCGA-OncoTree based on pre-computed slide embeddings. The TITAN embeddings of TCGA-OT are available in our huggingface model hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_csv = pd.read_csv('../datasets/tcga-ot_test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load pre-extracted TITAN slide embeddings for TCGA\n",
    "import pickle\n",
    "from huggingface_hub import hf_hub_download\n",
    "slide_feature_path = hf_hub_download(\n",
    "    \"MahmoodLab/TITAN\", \n",
    "    filename=\"TCGA_TITAN_features.pkl\",\n",
    ")\n",
    "with open(slide_feature_path, 'rb') as file:\n",
    "  data = pickle.load(file)\n",
    "slide_embeddings = torch.from_numpy(data['embeddings'][:])\n",
    "slide_names = np.array(data['filenames'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get indices of slide_names that are in the task csv\n",
    "slide_names_series = pd.Series(slide_names)\n",
    "indices = slide_names_series[slide_names_series.isin(task_csv['slide_id'])].index\n",
    "slide_embeddings = slide_embeddings[indices]\n",
    "slide_names = slide_names[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = []\n",
    "targets = []\n",
    "\n",
    "for slide_emb, slide_id in tqdm(zip(slide_embeddings, slide_names), total=len(slide_embeddings)):\n",
    "    with torch.autocast('cuda', torch.float16), torch.inference_mode():\n",
    "        slide_emb = slide_emb.to(device)\n",
    "        probs.append(model.zero_shot(slide_emb, classifier).cpu())\n",
    "    targets.append(label_dict[task_csv[task_csv['slide_id'] == slide_id][target].values[0]])\n",
    "probs_all = torch.cat(probs, dim=0)\n",
    "targets_all = torch.tensor(targets)\n",
    "preds_all = probs_all.argmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = get_eval_metrics(targets_all, preds_all, probs_all, roc_kwargs={'multi_class': 'ovo', 'average': 'macro'})\n",
    "for key, value in results.items():\n",
    "    print(f\"{key.split('/')[-1]: <12}: {value:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = {\n",
    "    \"targets\": targets_all,\n",
    "    \"preds\": preds_all,\n",
    "    \"probs\": probs_all,\n",
    "}\n",
    "bootstrap_kwargs = {'n': 1000, 'alpha': 0.95}\n",
    "results_mean, results_std = bootstrap(results_dict=outputs, **bootstrap_kwargs)\n",
    "for keys, values in results_mean.items():\n",
    "    print(f\"{keys.split('/')[-1]: <12}: {values:.4f} ± {results_std[keys]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "titan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
